package org.newcih.train;


import org.newcih.util.FileUtils;
import org.newcih.util.ImageUtils;
import org.newcih.util.StreamGobbler;
import org.springframework.util.StringUtils;

import java.io.File;
import java.io.FilenameFilter;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.StandardCopyOption;
import java.util.*;
import java.util.regex.Pattern;


/**
 * Drives the Tesseract command-line tools to train a font and to recognise images.
 *
 * <p>Every command is executed with {@code inputDataDir} as its working directory and with
 * stderr merged into stdout. An instance shares one {@link ProcessBuilder} between calls.
 *
 * @author newcih
 * @version 2018-01-29
 */
public class TessTrainer {

    /**
     * Generates a box file.
     */
    public static final String CMD_MAKE_BOX             =
            "tesseract IMAGE_FILE_NAME BOX_FILE_NAME_WITHOUT_EXTRACT batch.nochop makebox";
    /**
     * Generates a training (.tr) file.
     */
    public static final String CMD_TESS_TRAIN           =
            "tesseract IMAGE_FILE_NAME BOX_FILE_NAME_WITHOUT_EXTRACT -psm 7 nobatch box.train";
    /**
     * Generates the character set file.
     */
    public static final String CMD_UNICHARSET_EXTRACTOR = "unicharset_extractor";
    /**
     * Generates the shape file.
     */
    public static final String CMD_SHAPE_CLUSTERING     = "shapeclustering -F %s.font_properties -U unicharset";
    /**
     * Generates the clustered character feature files.
     */
    public static final String CMD_MF_TRAINING          = "mftraining -F %1$s.font_properties -U unicharset";
    /**
     * Generates the character normalisation feature file.
     */
    public static final String CMD_CN_TRAINING          = "cntraining";
    /**
     * Combines the training files.
     */
    public static final String CMD_COMBINE_TESSDATA     = "combine_tessdata %s.";
    /**
     * Recognises an image.
     */
    public static final String CMD_TESSERACT            = "tesseract IMAGE_FILE_NAME -l FONT_NAME OUTPUT_FILE_NAME";
    /**
     * Lists all fonts.
     */
    public static final String CMD_LIST_LANGS           = "tesseract --list-langs";

    /**
     * Matches the image file names accepted for training. Compiled once and matched
     * case-insensitively so the result does not depend on the default locale.
     */
    private static final Pattern IMAGE_FILE_PATTERN =
            Pattern.compile(".*\\.(tif|tiff|jpg|jpeg|png|bmp)$", Pattern.CASE_INSENSITIVE);

    /**
     * Process builder reused for every command; its working directory is {@code inputDataDir}.
     */
    ProcessBuilder pb;
    /**
     * Directory containing tesseract and its tool programs.
     */
    String         tessHome;
    /**
     * Directory containing the images; also the working directory of every command.
     */
    String         inputDataDir;
    /**
     * Font name.
     */
    String         font;
    /**
     * Directory that receives the finished font data; {@code null} when the
     * three-argument constructor was used.
     */
    String         tessData;

    /**
     * Constructor used for recognition. Leaves {@code tessData} unset.
     *
     * @param tessHome     directory of the executables
     * @param inputDataDir working directory that holds the images
     * @param font         font name
     */
    public TessTrainer(String tessHome, String inputDataDir, String font) {
        this(tessHome, inputDataDir, font, null);
    }

    /**
     * Constructor used for training.
     *
     * @param tessHome     directory of the executables
     * @param inputDataDir working directory that holds the images
     * @param font         font name
     * @param tessData     directory that receives the trained font data
     */
    public TessTrainer(String tessHome, String inputDataDir, String font, String tessData) {
        pb = new ProcessBuilder();
        pb.directory(new File(inputDataDir));
        pb.redirectErrorStream(true);
        this.tessHome = tessHome;
        this.inputDataDir = inputDataDir;
        this.font = font;
        this.tessData = tessData;
    }

    /**
     * Turns a command template into an argument list.
     *
     * <p>The template is split on whitespace and the first token (the program name) is
     * prefixed with the result of {@code FileUtils.composePath(tessHome)}.
     *
     * @param cmdStr command template, tokens separated by whitespace
     * @return a new, modifiable argument list; callers replace placeholders by index
     */
    public List<String> getCommand(String cmdStr) {
        // ArrayList rather than LinkedList: callers update arguments by index.
        List<String> cmd = new ArrayList<>(Arrays.asList(cmdStr.split("\\s+")));
        cmd.set(0, FileUtils.composePath(tessHome) + cmd.get(0));
        return cmd;
    }

    /**
     * Runs a command in the working directory and waits for it to finish.
     *
     * @param cmd program and arguments
     * @return the output collected from the process (stdout and stderr merged)
     * @throws RuntimeException if the process exits with a non-zero code; the message is
     *                          the collected output
     * @throws Exception        if the process cannot be started or the wait is interrupted
     */
    public String runCommand(List<String> cmd) throws Exception {
        StringBuilder cmdLine = new StringBuilder(100);
        for (String arg : cmd) {
            cmdLine.append(arg).append(" ");
        }
        System.out.println("即将执行 " + cmdLine);

        pb.command(cmd);
        Process process = pb.start();

        // Collect the process output while it runs.
        StreamGobbler outputGobbler = new StreamGobbler(process.getInputStream());
        outputGobbler.start();

        int exitCode = process.waitFor();

        // NOTE(review): presumably StreamGobbler reads on its own thread; if so it should be
        // joined before getMessage() is called, otherwise trailing output may be missing.
        // Confirm against StreamGobbler before changing.
        String output = outputGobbler.getMessage();
        System.out.println(output);

        if (exitCode != 0) {
            if (cmd.get(0).contains("shapeclustering")) {
                System.err.println("font_properties文件出错!");
            } else {
                System.err.println(output);
            }
            throw new RuntimeException(output);
        }

        return output;
    }

    /**
     * Recognises the text in an image.
     *
     * <p>Tesseract writes its result to a uniquely named text file in the working
     * directory. The file is read and then deleted immediately, so result files do not
     * accumulate while the JVM keeps running.
     *
     * @param imageFileName image file name
     * @return the recognised text with all {@code \n} characters removed
     * @throws Exception if the command fails or the result file cannot be read
     */
    public String tesseract(String imageFileName) throws Exception {
        List<String> cmd        = getCommand(CMD_TESSERACT);
        String       outputName = UUID.randomUUID().toString();
        cmd.set(1, imageFileName);
        cmd.set(3, font);
        cmd.set(4, outputName);

        File outputFile = new File(inputDataDir, outputName + ".txt");
        try {
            runCommand(cmd);
            return FileUtils.readTextFile(outputFile).replace("\n", "");
        } finally {
            // Fall back to delete-on-exit only when the immediate delete fails.
            if (outputFile.exists() && !outputFile.delete()) {
                outputFile.deleteOnExit();
            }
        }
    }

    /**
     * Lists all fonts known to tesseract.
     *
     * @return the non-empty output lines of {@code tesseract --list-langs} without the
     *         first one, which is not a font; empty when the command printed nothing
     * @throws Exception if the command fails
     */
    public List<String> listFonts() throws Exception {
        String       result   = runCommand(getCommand(CMD_LIST_LANGS));
        List<String> fontList = new ArrayList<>();
        // Accept both LF and CRLF so names do not keep a trailing '\r'.
        for (String line : result.split("\\r?\\n")) {
            if (!StringUtils.isEmpty(line)) {
                fontList.add(line);
            }
        }
        // The first entry is not a font.
        if (!fontList.isEmpty()) {
            fontList.remove(0);
        }
        return fontList;
    }

    /**
     * Generates a box file for an image with the tesseract tool.
     *
     * <p>NOTE(review): the output base name passed here is the font name, whereas
     * {@link #generateTraineddata()} looks for {@code <image name without extension>.box}.
     * Confirm that callers expect a box file named after the font.
     *
     * @param imageFileName image file name
     * @throws Exception if the command fails
     */
    public void makeBox(String imageFileName) throws Exception {
        List<String> cmd = getCommand(CMD_MAKE_BOX);
        cmd.set(1, imageFileName);
        cmd.set(2, font);
        runCommand(cmd);
    }

    /**
     * Runs the training pipeline and moves the resulting {@code <font>.traineddata} into
     * the {@code tessData} directory.
     *
     * @throws RuntimeException if no image with a matching box file exists, or a tool fails
     * @throws Exception        if a command cannot be run or the result cannot be moved
     */
    public void generateTraineddata() throws Exception {
        String[] files = getImageFilesWithBox();

        if (files.length == 0) {
            throw new RuntimeException("目录下缺乏box文件和相对应的图片文件");
        }

        // Create the font_properties file. A failure is reported but training continues.
        try {
            File fontpropFile = new File(inputDataDir, font + ".font_properties");
            FileUtils.createFile(fontpropFile);
            FileUtils.writeTextFile(fontpropFile, font + " 0 0 0 0 0");
        } catch (Exception e) {
            e.printStackTrace();
            System.out.println("生成font_properties文件失败");
        }

        // One training run per image that has a box file.
        List<String> cmd = getCommand(CMD_TESS_TRAIN);
        for (String file : files) {
            cmd.set(1, file);
            cmd.set(2, ImageUtils.stripExtension(file));
            runCommand(cmd);
        }

        // Extract the character set from every box file in the working directory.
        cmd = getCommand(CMD_UNICHARSET_EXTRACTOR);
        cmd.addAll(Arrays.asList(listWorkDirFilesBySuffix(".box")));
        runCommand(cmd);

        runShapeClustering();

        // Training is complete: move the font data into the tessData directory.
        String sourceTrainedDataFolder = FileUtils.composePath(inputDataDir, "tessdata");
        File   sourceTrainedData       = new File(sourceTrainedDataFolder + font + ".traineddata");
        String destTrainedDataFolder   = FileUtils.composePath(tessData);
        File   destTrainedData         = new File(destTrainedDataFolder + font + ".traineddata");
        org.apache.commons.io.FileUtils.moveFile(sourceTrainedData, destTrainedData);
        // Follow-up step for the caller: delete the working directory created for training.
    }

    /**
     * Runs shapeclustering, mftraining and cntraining over all {@code .tr} files, prefixes
     * the produced files with the font name and then combines them.
     *
     * @throws RuntimeException if no {@code .tr} file exists or a tool fails
     * @throws Exception        if a command cannot be run
     */
    private void runShapeClustering() throws Exception {
        String[] files = listWorkDirFilesBySuffix(".tr");

        if (files.length == 0) {
            throw new RuntimeException("找不到.tr文件，必须先生成该文件");
        }
        List<String> trFiles = Arrays.asList(files);

        List<String> cmd = getCommand(String.format(CMD_SHAPE_CLUSTERING, font));
        cmd.addAll(trFiles);
        runCommand(cmd);

        cmd = getCommand(String.format(CMD_MF_TRAINING, font));
        cmd.addAll(trFiles);
        runCommand(cmd);

        cmd = getCommand(CMD_CN_TRAINING);
        cmd.addAll(trFiles);
        runCommand(cmd);

        renameFile(inputDataDir, "inttemp", font + ".inttemp");
        renameFile(inputDataDir, "pffmtable", font + ".pffmtable");
        renameFile(inputDataDir, "normproto", font + ".normproto");
        renameFile(inputDataDir, "shapetable", font + ".shapetable");
        renameFile(inputDataDir, "unicharset", font + ".unicharset");

        runDictionary();
    }

    /**
     * Combines the font files into {@code <font>.traineddata} and moves the result into
     * the {@code tessdata} sub-directory of the working directory.
     *
     * @throws RuntimeException if {@code <font>.unicharset} is missing or the tool fails
     * @throws IOException      if the sub-directory cannot be created or the file cannot
     *                          be moved
     * @throws Exception        if the command cannot be run
     */
    private void runDictionary() throws Exception {
        if (!new File(inputDataDir, font + ".unicharset").exists()) {
            String msg = String.format("目录下不存在%1$s.unicharset文件，需要先生成该文件", font);
            throw new RuntimeException(msg);
        }

        runCommand(getCommand(String.format(CMD_COMBINE_TESSDATA, font)));

        String traineddata = font + ".traineddata";
        File   tessdata    = new File(inputDataDir, "tessdata");
        if (!tessdata.isDirectory() && !tessdata.mkdirs()) {
            throw new IOException("Unable to create directory " + tessdata);
        }
        // Fail loudly instead of ignoring a failed rename, and replace the output of an
        // earlier run so stale data is never picked up.
        Files.move(new File(inputDataDir, traineddata).toPath(),
                   new File(tessdata, traineddata).toPath(),
                   StandardCopyOption.REPLACE_EXISTING);
    }

    /**
     * Renames a file within one directory, replacing an existing target.
     *
     * <p>The method is best effort and never throws for a missing source or a failed
     * move; both cases are reported on stderr.
     *
     * @param dir         directory containing the old and the new file
     * @param oldFileName current file name
     * @param newFileName new file name
     */
    public static void renameFile(String dir, String oldFileName, String newFileName) {
        File oldFile = new File(FileUtils.composePath(dir) + oldFileName);
        File newFile = new File(FileUtils.composePath(dir) + newFileName);
        if (!oldFile.exists()) {
            System.err.println("renameFile: source file not found, skipped: " + oldFile);
            return;
        }
        try {
            Files.move(oldFile.toPath(), newFile.toPath(), StandardCopyOption.REPLACE_EXISTING);
        } catch (IOException e) {
            System.err.println("renameFile: cannot rename " + oldFile + " to " + newFile + ": " + e);
        }
    }

    /**
     * Lists the names in the working directory accepted by the filter.
     *
     * @param filter file name filter
     * @return the matching names; empty (never {@code null}) when the directory cannot be
     *         listed
     */
    private String[] listWorkDirFiles(FilenameFilter filter) {
        String[] names = new File(inputDataDir).list(filter);
        // File.list returns null when the path is not a directory or an I/O error occurs.
        return names != null ? names : new String[0];
    }

    /**
     * Lists the names in the working directory that end with the given suffix.
     *
     * @param suffix required file name ending, compared case-sensitively
     * @return the matching names, possibly empty
     */
    private String[] listWorkDirFilesBySuffix(final String suffix) {
        return listWorkDirFiles(new FilenameFilter() {
            @Override
            public boolean accept(File dir, String fileName) {
                return fileName.endsWith(suffix);
            }
        });
    }

    /**
     * Returns the image files available for training.
     *
     * @return names of the tif, tiff, jpg, jpeg, png and bmp files in the working
     *         directory, possibly empty
     */
    private String[] getImageFiles() {
        return listWorkDirFiles(new FilenameFilter() {
            @Override
            public boolean accept(File dir, String fileName) {
                return IMAGE_FILE_PATTERN.matcher(fileName).matches();
            }
        });
    }

    /**
     * Returns the image files for which a box file has already been generated.
     *
     * @return names of the images that have a sibling {@code <name without extension>.box}
     */
    private String[] getImageFilesWithBox() {
        List<String> filesWithBox = new ArrayList<>();
        for (String file : getImageFiles()) {
            String withoutExt = ImageUtils.stripExtension(file);
            if (new File(inputDataDir, withoutExt + ".box").exists()) {
                filesWithBox.add(file);
            }
        }

        return filesWithBox.toArray(new String[0]);
    }

    /**
     * Entry point placeholder; intentionally does nothing.
     *
     * @param args ignored
     */
    public static void main(String[] args) throws Exception {
    }

}
